{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize image-specific class saliency with backpropagation\n",
    "\n",
    "---\n",
    "\n",
    "A quick demo of creating saliency maps for CNNs using [FlashTorch 🔦](https://github.com/MisaOgura/flashtorch).\n",
    "\n",
    "\n",
    "❗This notebook is for those who are using this notebook in **Google Colab**.\n",
    "\n",
    "If you aren't on Google Colab already, please head to the Colab version of this notebook **[here](https://colab.research.google.com/github/MisaOgura/flashtorch/blob/master/examples/visualise_saliency_with_backprop_colab.ipynb)** to execute.\n",
    "\n",
    "---\n",
    "\n",
    "The gradients obtained can be used to visualise an image-specific class saliency map, which can gives some intuition on regions within the input image that contribute the most (and least) to the corresponding output.\n",
    "\n",
    "More details on saliency maps: [Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps](https://arxiv.org/pdf/1312.6034.pdf)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 0. Set up\n",
    "\n",
    "A GPU runtime is available on Colab for free, from the `Runtime` tab on the top menu bar.\n",
    "\n",
    "It is **recommended to use GPU** as a runtime for the enhanced speed of computation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install flashtorch\n",
    "\n",
    "!pip install flashtorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download example images\n",
    "\n",
    "!mkdir -p images\n",
    "\n",
    "!wget -nv \\\n",
    "    https://github.com/MisaOgura/flashtorch/raw/master/examples/images/great_grey_owl.jpg \\\n",
    "    https://github.com/MisaOgura/flashtorch/raw/master/examples/images/peacock.jpg        \\\n",
    "    https://github.com/MisaOgura/flashtorch/raw/master/examples/images/toucan.jpg         \\\n",
    "    -P /content/images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision.models as models\n",
    "\n",
    "from flashtorch.utils import apply_transforms, load_image\n",
    "from flashtorch.saliency import Backprop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Load an image "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image = load_image('/content/images/great_grey_owl.jpg')\n",
    "\n",
    "plt.imshow(image)\n",
    "plt.title('Original image')\n",
    "plt.axis('off');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Load a pre-trained Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.alexnet(pretrained=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Create an instance of Backprop with the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "backprop = Backprop(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Visualize saliency maps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transform the input image to a tensor\n",
    "\n",
    "owl = apply_transforms(image)\n",
    "\n",
    "# Set a target class from ImageNet task: 24 in case of great gray owl\n",
    "\n",
    "target_class = 24\n",
    "\n",
    "# Ready to roll!\n",
    "\n",
    "backprop.visualize(owl, target_class, guided=True, use_gpu=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. What about other birds?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What makes peacock a peacock...?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "peacock = apply_transforms(load_image('/content/images/peacock.jpg'))\n",
    "backprop.visualize(peacock, 84, guided=True, use_gpu=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Or a toucan?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "toucan = apply_transforms(load_image('/content/images/toucan.jpg'))\n",
    "backprop.visualize(toucan, 96, guided=True, use_gpu=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please try out other models/images too!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
